Latent space encoding using LSTMs: Finding similar word context

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics.pairwise import euclidean_distances

import numpy as np
import pandas as pd
In [2]:
maxlen = 3
max_features = 1000

Let's prepare the word-index mapping that the Keras data loader uses, so we can encode input text and decode model output back into words:

In [3]:
word_to_id = keras.datasets.imdb.get_word_index()
word_to_id = {k:(v+3) for k,v in word_to_id.items()}
word_to_id["<PAD>"] = 0
word_to_id["<START>"] = 1
word_to_id["<UNK>"] = 2
word_to_id["<UNUSED>"] = 3
id_to_word = {value:key for key,value in word_to_id.items()}
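The offset of 3 makes room for the special tokens at ids 0–3. As a sanity check, the mapping should round-trip: encoding a word and decoding it again returns the original word. A minimal sketch with a tiny stand-in index (the raw ids shown are taken from the real IMDB index for illustration; the full index comes from `keras.datasets.imdb.get_word_index()`):

```python
# Tiny stand-in for the raw IMDB word index.
raw_index = {"the": 1, "movie": 17, "great": 84}

# Shift every id by 3 to make room for the special tokens.
word_to_id = {w: i + 3 for w, i in raw_index.items()}
word_to_id.update({"<PAD>": 0, "<START>": 1, "<UNK>": 2, "<UNUSED>": 3})
id_to_word = {i: w for w, i in word_to_id.items()}

# Round trip: encode a sequence of words and decode it back.
words = ["great", "movie"]
ids = [word_to_id[w] for w in words]
assert [id_to_word[i] for i in ids] == words
print(ids)  # [87, 20]
```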

Get the data

We load the data and preprocess it so the LSTMs can process it. This also handles padding, in case a review is shorter than the defined sequence length.

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
In [5]:
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)
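With `maxlen=3`, most reviews get truncated rather than padded. By default, `pad_sequences` keeps the last `maxlen` tokens (`truncating='pre'`) and left-pads shorter sequences with 0, the `<PAD>` id (`padding='pre'`). A minimal pure-Python sketch of that default behavior:

```python
def pad_or_truncate(seq, maxlen, pad_value=0):
    """Mimic keras pad_sequences defaults: pre-padding and pre-truncation."""
    if len(seq) >= maxlen:
        return seq[-maxlen:]  # keep the last maxlen tokens
    return [pad_value] * (maxlen - len(seq)) + seq  # left-pad with <PAD>

print(pad_or_truncate([1, 5, 9, 19, 178, 32], 3))  # [19, 178, 32]
print(pad_or_truncate([7, 8], 3))                  # [0, 7, 8]
```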
In [6]:
x_train.shape
Out[6]:
(25000, 3)
In [7]:
x_train[0]
Out[7]:
array([ 19, 178,  32], dtype=int32)
In [8]:
rows = []  # renamed from `all` to avoid shadowing the built-in
for it in range(x_train.shape[0]):
    # One-hot encode each token: row jt gets a 1 at the token's id.
    row = np.zeros((maxlen, max_features))
    for jt in range(maxlen):
        row[jt, x_train[it, jt]] = 1
    rows.append(row)
In [9]:
data_enc = np.array(rows)
In [10]:
data_enc.shape
Out[10]:
(25000, 3, 1000)
In [11]:
np.argmax(data_enc[0], axis=1).tolist()
Out[11]:
[19, 178, 32]
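The Python loop above works but is slow. The same one-hot tensor can be built in one shot with NumPy fancy indexing: row `i` of an identity matrix is the one-hot vector for id `i`, so indexing the identity matrix by the token-id array yields the full encoding. A sketch with two example padded sequences:

```python
import numpy as np

x = np.array([[19, 178, 32], [2, 385, 39]])  # two example padded sequences
max_features = 1000

# Index an identity matrix by token id: row i of eye(n) is the one-hot for i.
data_enc = np.eye(max_features)[x]  # shape (2, 3, 1000)

assert data_enc.shape == (2, 3, 1000)
assert np.argmax(data_enc[0], axis=1).tolist() == [19, 178, 32]
```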

Build the model

In [12]:
inputs = keras.Input(shape=(None,), dtype="int32")
x = layers.Embedding(max_features, 128)(inputs)
x = layers.Bidirectional(layers.LSTM(128))(x)
x = layers.BatchNormalization()(x)
encoded = layers.Dense(3)(x)
In [13]:
x = layers.Dense(3)(layers.RepeatVector(maxlen)(encoded))
x = layers.BatchNormalization()(x)
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
decoded = layers.TimeDistributed(layers.Dense(max_features))(x)
decoded = layers.Softmax()(decoded)
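`RepeatVector` simply tiles the 3-dimensional latent vector across the time axis, so the decoder LSTM receives one copy per timestep. In NumPy terms (a sketch, using a latent vector of the shape produced by the encoder above):

```python
import numpy as np

latent = np.array([[-1.68, 0.61, -1.43]])  # shape (batch=1, 3)
maxlen = 3

# RepeatVector: (batch, 3) -> (batch, maxlen, 3), same vector at each step.
repeated = np.repeat(latent[:, None, :], maxlen, axis=1)

assert repeated.shape == (1, maxlen, 3)
assert (repeated[0, 0] == repeated[0, 2]).all()  # identical at every timestep
```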
In [14]:
model = keras.Model(inputs, decoded)
model.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding (Embedding)        (None, None, 128)         128000    
_________________________________________________________________
bidirectional (Bidirectional (None, 256)               263168    
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
dense (Dense)                (None, 3)                 771       
_________________________________________________________________
repeat_vector (RepeatVector) (None, 3, 3)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 3, 3)              12        
_________________________________________________________________
batch_normalization_1 (Batch (None, 3, 3)              12        
_________________________________________________________________
bidirectional_1 (Bidirection (None, 3, 256)            135168    
_________________________________________________________________
time_distributed (TimeDistri (None, 3, 1000)           257000    
_________________________________________________________________
softmax (Softmax)            (None, 3, 1000)           0         
=================================================================
Total params: 785,155
Trainable params: 784,637
Non-trainable params: 518
_________________________________________________________________
In [15]:
encoder = keras.Model(inputs, encoded)
# Note: this "decoder" maps raw token ids to reconstructions, i.e. it is
# identical to `model`; a true decoder would take the latent vector as input.
decoder = keras.Model(inputs, decoded)
In [16]:
model.compile(optimizer='adam', loss='categorical_crossentropy')
In [17]:
model.fit(x_train, data_enc, epochs=50)
Epoch 1/50
782/782 [==============================] - 7s 9ms/step - loss: 3.9955
Epoch 2/50
782/782 [==============================] - 7s 9ms/step - loss: 3.2858
Epoch 3/50
782/782 [==============================] - 7s 9ms/step - loss: 3.0419
Epoch 4/50
782/782 [==============================] - 7s 9ms/step - loss: 2.9094
Epoch 5/50
782/782 [==============================] - 7s 9ms/step - loss: 2.8059
Epoch 6/50
782/782 [==============================] - 7s 9ms/step - loss: 2.7413
Epoch 7/50
782/782 [==============================] - 7s 9ms/step - loss: 2.6699
Epoch 8/50
782/782 [==============================] - 7s 9ms/step - loss: 2.6317
Epoch 9/50
782/782 [==============================] - 8s 10ms/step - loss: 2.5801
Epoch 10/50
782/782 [==============================] - 12s 16ms/step - loss: 2.5328
Epoch 11/50
782/782 [==============================] - 16s 20ms/step - loss: 2.4928
Epoch 12/50
782/782 [==============================] - 12s 15ms/step - loss: 2.4553
Epoch 13/50
782/782 [==============================] - 8s 11ms/step - loss: 2.4133
Epoch 14/50
782/782 [==============================] - 7s 9ms/step - loss: 2.4140
Epoch 15/50
782/782 [==============================] - 7s 9ms/step - loss: 2.3683
Epoch 16/50
782/782 [==============================] - 7s 9ms/step - loss: 2.3458
Epoch 17/50
782/782 [==============================] - 7s 9ms/step - loss: 2.3270
Epoch 18/50
782/782 [==============================] - 8s 10ms/step - loss: 2.2790
Epoch 19/50
782/782 [==============================] - 8s 11ms/step - loss: 2.2486
Epoch 20/50
782/782 [==============================] - 8s 11ms/step - loss: 2.2293
Epoch 21/50
782/782 [==============================] - 9s 11ms/step - loss: 2.2228
Epoch 22/50
782/782 [==============================] - 9s 11ms/step - loss: 2.1896
Epoch 23/50
782/782 [==============================] - 9s 12ms/step - loss: 2.1791
Epoch 24/50
782/782 [==============================] - 9s 12ms/step - loss: 2.1519
Epoch 25/50
782/782 [==============================] - 9s 11ms/step - loss: 2.1203
Epoch 26/50
782/782 [==============================] - 9s 12ms/step - loss: 2.0964
Epoch 27/50
782/782 [==============================] - 10s 13ms/step - loss: 2.0785
Epoch 28/50
782/782 [==============================] - 9s 12ms/step - loss: 2.0771
Epoch 29/50
782/782 [==============================] - 8s 10ms/step - loss: 2.0662
Epoch 30/50
782/782 [==============================] - 7s 9ms/step - loss: 2.0460
Epoch 31/50
782/782 [==============================] - 7s 9ms/step - loss: 2.0257
Epoch 32/50
782/782 [==============================] - 7s 9ms/step - loss: 2.0201
Epoch 33/50
782/782 [==============================] - 8s 10ms/step - loss: 1.9941
Epoch 34/50
782/782 [==============================] - 8s 10ms/step - loss: 1.9944
Epoch 35/50
782/782 [==============================] - 8s 11ms/step - loss: 1.9602
Epoch 36/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9452
Epoch 37/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9474
Epoch 38/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9218
Epoch 39/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9267
Epoch 40/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9189
Epoch 41/50
782/782 [==============================] - 9s 11ms/step - loss: 1.9018
Epoch 42/50
782/782 [==============================] - 8s 11ms/step - loss: 1.8768
Epoch 43/50
782/782 [==============================] - 8s 10ms/step - loss: 1.8753
Epoch 44/50
782/782 [==============================] - 8s 10ms/step - loss: 1.8558
Epoch 45/50
782/782 [==============================] - 9s 12ms/step - loss: 1.8645
Epoch 46/50
782/782 [==============================] - 10s 12ms/step - loss: 1.8537
Epoch 47/50
782/782 [==============================] - 9s 11ms/step - loss: 1.8287
Epoch 48/50
782/782 [==============================] - 9s 11ms/step - loss: 1.8363
Epoch 49/50
782/782 [==============================] - 8s 11ms/step - loss: 1.8081
Epoch 50/50
782/782 [==============================] - 8s 10ms/step - loss: 1.8121
Out[17]:
<tensorflow.python.keras.callbacks.History at 0x7f9fc969ac90>
In [18]:
res = model.predict(x_test)
In [19]:
res.shape
Out[19]:
(25000, 3, 1000)
In [20]:
np.argmax(data_enc[100], axis=1)
Out[20]:
array([ 46,   7, 158])
In [21]:
np.argmax(res[100], axis=1)
Out[21]:
array([  2,  42, 358])
In [22]:
x_test[100]
Out[22]:
array([  2, 385,  39], dtype=int32)
In [23]:
res[100]
Out[23]:
array([[2.5847822e-13, 2.8779176e-13, 9.9815732e-01, ..., 2.0068691e-10,
        2.6696442e-11, 3.0582864e-08],
       [1.2270926e-10, 1.4423923e-10, 5.3281087e-02, ..., 3.2689597e-06,
        1.3396644e-08, 2.1970925e-06],
       [2.8203293e-09, 2.5865201e-09, 1.2598239e-02, ..., 2.2434969e-07,
        2.5665216e-07, 1.3000058e-06]], dtype=float32)

Displaying the latent space

In [24]:
import plotly.graph_objects as go
In [25]:
enc = encoder.predict(x_train)
In [26]:
enc[0]
Out[26]:
array([-1.6785883,  0.6114942, -1.4339578], dtype=float32)

Display by sentiment

In [27]:
plot_data = [[], []]
In [28]:
for it in range(y_train.shape[0]):
    plot_data[y_train[it]].append(enc[it])
In [29]:
fig = go.Figure([
    go.Scatter3d(
        x=np.array(plot_data[0])[:, 0],
        y=np.array(plot_data[0])[:, 1],
        z=np.array(plot_data[0])[:, 2],
        mode='markers',
        marker={'size': 1},
    ),
    go.Scatter3d(
        x=np.array(plot_data[1])[:, 0],
        y=np.array(plot_data[1])[:, 1],
        z=np.array(plot_data[1])[:, 2],
        mode='markers',
        marker={'size': 1},
    )
])
fig.write_html('plot_no_sentiment.html')
In [30]:
fig.show()

See the attached HTML file to explore the plot; interactive Plotly figures usually do not survive exporting the notebook.

Let's test it out...

First we create the input sequence that we want to run the model against.

In [47]:
testing = [word_to_id[word] for word in ['great', 'family', 'movie']]
In [48]:
testing
Out[48]:
[87, 223, 20]

Let's run the encoder and see where our sequence falls in the latent space.

In [49]:
testing_space = encoder.predict(np.array([testing]))
In [50]:
testing_space
Out[50]:
array([[ -0.75202096, -11.560089  ,   8.997971  ]], dtype=float32)

We can use the Euclidean distance to find which training sequences lie closest to our input sequence in the latent space. After sorting the array by distance, the closest ones appear at the top:

In [51]:
full_distances = euclidean_distances(np.array([testing_space[0]]), enc)
In [52]:
distances = np.stack(
    [
        full_distances.reshape(25000, 1),          # distance to each sequence
        np.array([[it] for it in range(25000)]),   # its index in x_train
    ],
    axis=1
).reshape(25000, 2).tolist()
distances = sorted(distances, key=lambda x: x[0])  # nearest first
In [53]:
distances[:10]
Out[53]:
[[0.24382853507995605, 10455.0],
 [0.5178334712982178, 17980.0],
 [0.9123451113700867, 8525.0],
 [0.9339559674263, 1228.0],
 [0.9391999244689941, 334.0],
 [0.9676989912986755, 5195.0],
 [0.9959375262260437, 962.0],
 [0.996581494808197, 15999.0],
 [0.9965817928314209, 690.0],
 [0.9965817928314209, 1258.0]]
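The stack/reshape dance above can be replaced by `np.argsort` on the flat distance vector, which returns indices ordered nearest-first. A sketch with toy latent codes (reproducing the real distances would require the trained encoder):

```python
import numpy as np

query = np.array([-0.75, -11.56, 9.0])       # hypothetical latent query point
enc = np.array([[0.0, 0.0, 0.0],
                [-0.8, -11.5, 9.1],
                [1.0, 2.0, 3.0]])            # toy latent codes

dists = np.linalg.norm(enc - query, axis=1)  # Euclidean distance to each code
order = np.argsort(dists)                    # indices sorted nearest-first

print(order[0])  # 1 -> the closest toy code
```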

Now we can display all the training sequences that the model encoded as "similar" to our input sequence.

In [54]:
for dist, key in distances[:20]:
    key = int(key)
    print(key, dist, [id_to_word[idx] for idx in x_train[key]])
10455 0.24382853507995605 ['decent', 'family', 'movie']
17980 0.5178334712982178 ['long', 'long', 'time']
8525 0.9123451113700867 ['long', 'sad', 'movie']
1228 0.9339559674263 ['great', 'dumb', 'movie']
334 0.9391999244689941 ['nice', 'long', 'look']
5195 0.9676989912986755 ['wonderful', 'thing', 'ever']
962 0.9959375262260437 ['an', 'excellent', 'movie']
15999 0.996581494808197 ['worth', 'your', 'time']
690 0.9965817928314209 ['worth', 'your', 'time']
1258 0.9965817928314209 ['worth', 'your', 'time']
1363 0.9965817928314209 ['worth', 'your', 'time']
7133 0.9965817928314209 ['worth', 'your', 'time']
9480 0.9965817928314209 ['worth', 'your', 'time']
11003 0.9965817928314209 ['worth', 'your', 'time']
16976 0.9965817928314209 ['worth', 'your', 'time']
18759 0.9965817928314209 ['worth', 'your', 'time']
24077 0.9965817928314209 ['worth', 'your', 'time']
5348 1.0082206726074219 ['great', '<UNK>', 'movie']
22993 1.018782377243042 ['fun', 'family', 'movie']
4205 1.0422136783599854 ['an', '<UNK>', 'movie']